{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "965e4d64-71ac-4f23-9db2-6e742ece1da7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel\n",
    "%pip install -q zarr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64ac080d-112a-428d-9566-dd97a5ecf98a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from os.path import join\n",
    "\n",
    "import dask\n",
    "import dask.array as da\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import numba\n",
    "\n",
    "from numba.typed import Dict\n",
    "from numba import prange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed416b30-357c-4e2e-8b0e-aa0cea2fb52f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Root folder of the source data; the train/, val/ and test/ sub-folders are read from here\n",
    "PATH = '/mnt/dssfs02/cxg_census/data_2023_05_15'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe375488-7aeb-4045-8220-810dd2ae57c5",
   "metadata": {},
   "source": [
    "# Get indices of lung cells for subsetting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d7ae22f-901f-4530-a47a-4b24eafc195d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _load_split(split):\n",
    "    \"\"\"Load one split stored under PATH.\n",
    "\n",
    "    Returns the obs table (index reset to 0..n-1) and the 'X' component of the\n",
    "    split's zarr store opened as a dask array.\n",
    "    \"\"\"\n",
    "    obs_table = pd.read_parquet(join(PATH, f'{split}/obs.parquet')).reset_index(drop=True)\n",
    "    x_array = da.from_zarr(join(PATH, f'{split}/zarr'), component='X')\n",
    "    return obs_table, x_array\n",
    "\n",
    "\n",
    "obs_train, x_train = _load_split('train')\n",
    "obs_val, x_val = _load_split('val')\n",
    "obs_test, x_test = _load_split('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7dc528c-971b-40c9-b6cb-aa98507a5e11",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Only the train copy of the var table is read; the same table is written into every split when saving\n",
    "var = pd.read_parquet(join(PATH, 'train/var.parquet'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "399386bf-69a5-49f5-91ca-c1f716da42ca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def drop_unused_categories(obs):\n",
    "    \"\"\"Remove unused category levels from every categorical column of ``obs``.\n",
    "\n",
    "    The frame is modified in place; columns with any other dtype are left untouched.\n",
    "    \"\"\"\n",
    "    for col in obs.columns:\n",
    "        if obs[col].dtype.name == 'category':\n",
    "            obs[col] = obs[col].cat.remove_unused_categories()\n",
    "\n",
    "\n",
    "# same clean-up for all three splits instead of three copy-pasted loops\n",
    "for obs_frame in (obs_train, obs_val, obs_test):\n",
    "    drop_unused_categories(obs_frame)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a22890a-a765-49b2-91dc-f3741f9c8cc4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "rng = np.random.default_rng(seed=1)\n",
    "\n",
    "# split name -> shuffled row indices of the cells whose tissue_general is 'lung'\n",
    "subset_idxs = {}\n",
    "\n",
    "\n",
    "# A single generator is advanced split by split, so the train -> val -> test order\n",
    "# is part of what makes the shuffles reproducible.\n",
    "for split, obs in [('train', obs_train), ('val', obs_val), ('test', obs_test)]:\n",
    "    # copy=True: shuffle() works in place, and without a copy to_numpy() may return a\n",
    "    # view of the index data (read-only under pandas Copy-on-Write).\n",
    "    idx_subset = obs[obs.tissue_general == 'lung'].index.to_numpy(copy=True)\n",
    "    rng.shuffle(idx_subset)\n",
    "    subset_idxs[split] = idx_subset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "403b29c0-3959-49da-b18e-c7fffec55400",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Inspect the shuffled row indices selected for each split\n",
    "subset_idxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82b31021-3af7-4e38-97f6-36aa52f29dd1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bdb667bf-2a00-4b44-837d-5720f4727e3b",
   "metadata": {},
   "source": [
    "# Store lung-only subset to disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ded15de8-8238-4bf0-ba19-a45de784d794",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Output root; one sub-folder per split is created below it\n",
    "SAVE_PATH = '/mnt/dssfs02/cxg_census/data_2023_05_15_lung_only'\n",
    "# Number of rows per chunk of the stored X array (each chunk spans all columns)\n",
    "CHUNK_SIZE = 16384"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e23192f-3d9b-4f5f-bb71-ddf10b6fe63a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Writer settings shared by the var and obs tables\n",
    "parquet_kwargs = dict(engine='pyarrow', compression='snappy', index=None)\n",
    "\n",
    "splits = {\n",
    "    'train': (x_train, obs_train),\n",
    "    'val': (x_val, obs_val),\n",
    "    'test': (x_test, obs_test),\n",
    "}\n",
    "\n",
    "for split, (x, obs) in splits.items():\n",
    "    idxs = subset_idxs[split]\n",
    "    # Indexing with the shuffled (out-of-order) indices is intentional:\n",
    "    # it shuffles the rows so that cells of one data set are broken up.\n",
    "    X_split = x[idxs, :].rechunk((CHUNK_SIZE, -1))\n",
    "    obs_split = obs.iloc[idxs, :]\n",
    "\n",
    "    save_dir = join(SAVE_PATH, split)\n",
    "    # No exist_ok here: raises FileExistsError if the split folder already exists\n",
    "    os.makedirs(save_dir)\n",
    "\n",
    "    var.to_parquet(path=join(save_dir, 'var.parquet'), **parquet_kwargs)\n",
    "    obs_split.to_parquet(path=join(save_dir, 'obs.parquet'), **parquet_kwargs)\n",
    "    da.to_zarr(\n",
    "        X_split,\n",
    "        join(save_dir, 'zarr'),\n",
    "        component='X',\n",
    "        compute=True,\n",
    "        compressor='default',\n",
    "        order='C',\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e964f8d5-489e-40b5-a382-3019f72a83fb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
